{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Callbacks\n",
    "\n",
    "author: Jacob Schreiber <br>\n",
    "contact: jmschreiber91@gmail.com\n",
    "\n",
    "It is sometimes convenient to be able to implement custom code at certain points in the training process. A \"callback\" is one way of doing this. Essentially, a callback is an object that has certain methods implemented. When the object is passed in to the `fit` method of a pomegranate object, these methods will be automatically called at predetermined points during training. These methods an implement a wide variety of functionality, but a few common callbacks include model checkpointing, where the model is written out to disk after each epoch, early stopping, where the training of a model stops early based on performance on a validation set, and even TensorBoard, which displays the results of training of multiple models.\n",
    "\n",
    "Callbacks are implemented in pomegranate using a similar approach to that of keras. The base callback looks like the following: \n",
    "\n",
    "```python\n",
    "class Callback(object):\n",
    "\t\"\"\"An object that adds functionality during training.\n",
    "\n",
    "\tA callback is a function or group of functions that can be executed during\n",
    "\tthe training process for any of pomegranate's models that have iterative\n",
    "\ttraining procedures. A callback can be called at three stages-- the\n",
    "\tbeginning of training, at the end of each epoch (or iteration), and at\n",
    "\tthe end of training. Users can define any functions that they wish in\n",
    "\tthe corresponding functions.\n",
    "\t\"\"\"\n",
    "\n",
    "\tdef __init__(self):\n",
    "\t\tself.model = None\n",
    "\t\tself.params = None\n",
    "\n",
    "\tdef on_training_begin(self):\n",
    "\t\t\"\"\"Functionality to add to the beginning of training.\n",
    "\n",
    "\t\tThis method will be called at the beginning of each model's training\n",
    "\t\tprocedure.\n",
    "\t\t\"\"\"\n",
    "\n",
    "\t\tpass\n",
    "\n",
    "\tdef on_training_end(self, logs):\n",
    "\t\t\"\"\"Functionality to add to the end of training.\n",
    "\n",
    "\t\tThis method will be called at the end of each model's training\n",
    "\t\tprocedure.\n",
    "\t\t\"\"\"\n",
    "\n",
    "\t\tpass\n",
    "\n",
    "\tdef on_epoch_end(self, logs):\n",
    "\t\t\"\"\"Functionality to add to the end of each epoch.\n",
    "\n",
    "\t\tThis method will be called at the end of each epoch during the model's\n",
    "\t\titerative training procedure.\n",
    "\t\t\"\"\"\n",
    "\n",
    "\t\tpass\n",
    "```\n",
    "\n",
    "During the training process the `self.model` attribute gets set to the model that is being trained, allowing users to interact with it and use the methods. \n",
    "\n",
    "A user can define a custom callback by simply by inheriting from this object (in pomegranate.callbacks) and implementing the methods that they care about. This doesn't have to be all of the methods.\n",
    "\n",
    "There are a few callbacks that are built-in to pomegranate:\n",
    "1. ModelCheckpoint: This callback will save a copy of the model every few epochs in case the user is performing a long-running job.\n",
    "2. History: This callback will save the logs generated during training and return them in a convenient object.\n",
    "3. CSVLogger: This callback is similar to History except that it will save the logs to a given CSV file\n",
    "4. LambdaCallback: This callback allows one to pass in function objects for each of the three methods to execute rather than define ones own custom callback."
   ]
  },
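  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea behind a function-based callback such as LambdaCallback can be sketched as follows. `FunctionCallback` is a hypothetical stand-in written for this example. It is not pomegranate's `LambdaCallback`, whose exact signature should be checked in the library itself.\n",
    "\n",
    "```python\n",
    "class FunctionCallback(object):\n",
    "    # Hypothetical stand-in: wraps plain functions as the three hooks.\n",
    "    def __init__(self, on_training_begin=None, on_epoch_end=None,\n",
    "                 on_training_end=None):\n",
    "        self.model = None\n",
    "        self.params = None\n",
    "        self._begin = on_training_begin\n",
    "        self._epoch_end = on_epoch_end\n",
    "        self._end = on_training_end\n",
    "\n",
    "    def on_training_begin(self):\n",
    "        if self._begin is not None:\n",
    "            self._begin()\n",
    "\n",
    "    def on_epoch_end(self, logs):\n",
    "        if self._epoch_end is not None:\n",
    "            self._epoch_end(logs)\n",
    "\n",
    "    def on_training_end(self, logs):\n",
    "        if self._end is not None:\n",
    "            self._end(logs)\n",
    "\n",
    "\n",
    "seen = []\n",
    "callback = FunctionCallback(on_epoch_end=lambda logs: seen.append(logs['epoch']))\n",
    "\n",
    "# Call the hooks by hand, as a training loop would.\n",
    "callback.on_training_begin()\n",
    "for epoch in (1, 2):\n",
    "    callback.on_epoch_end({'epoch': epoch})\n",
    "callback.on_training_end({'epoch': 2})\n",
    "print(seen)\n",
    "```\n",
    "\n",
    "Hooks that are not supplied are simply skipped, which mirrors the point that a callback does not have to implement every method."
   ]
  },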
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sun Dec 01 2019 \n",
      "\n",
      "numpy 1.17.2\n",
      "scipy 1.3.1\n",
      "pomegranate 0.11.3\n",
      "\n",
      "compiler   : GCC 7.3.0\n",
      "system     : Linux\n",
      "release    : 4.15.0-66-generic\n",
      "machine    : x86_64\n",
      "processor  : x86_64\n",
      "CPU cores  : 8\n",
      "interpreter: 64bit\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import numpy\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn; seaborn.set_style('whitegrid')\n",
    "\n",
    "from pomegranate import *\n",
    "\n",
    "numpy.random.seed(0)\n",
    "numpy.set_printoptions(suppress=True)\n",
    "\n",
    "%load_ext watermark\n",
    "%watermark -m -n -p numpy,scipy,pomegranate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using a callback\n",
    "\n",
    "Let's first take a look at how to use the built-in callbacks. We'll start off with the History callback, which is already automatically created and updated during training. You can return the history object with the data stored in it using the `return_history` parameter during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = numpy.random.randn(10000, 13)\n",
    "\n",
    "d1 = MultivariateGaussianDistribution(numpy.zeros(13), numpy.eye(13))\n",
    "d2 = MultivariateGaussianDistribution(numpy.ones(13), numpy.eye(13))\n",
    "\n",
    "model = GeneralMixtureModel([d1, d2])\n",
    "_, history = model.fit(X, return_history=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After training we can use the history object to make several useful plots. An intuitive plot is the log probability of the data set given the model $P(D|M)$ over the number of epochs of training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0, 0.5, 'Log Probability')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaMAAAEJCAYAAAA5Ekh8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deXwV1fnH8U9IQkIIYd8DAoqPQUBZRFutC9q61q3SKmqt1dZa+9NqW60/22oXa+0mWpcuKtXaitRqtbVq3aj+rIoEUNHwCLKvAUJIAuRmu78/ZoLXmA3IzdyQ7/v1yit3zpmZ+8x5XfJwZs49Jy0ejyMiIhKlLlEHICIiomQkIiKRUzISEZHIKRmJiEjklIxERCRyGVEH0BEtXLgwnpWV1er9Y7EYu7N/Z6Q2ap7ap3lqn5alQhvt2LFj86RJk/o3VqdktAeysrIoKCho9f5FRUW7tX9npDZqntqneWqflqVCGxUWFq5sqk636UREJHJKRiIiEjklIxERiZySkYiIRE7JSEREIqdkJCIikVMyEhGRyOl7RiIinVysppYdsVoqYjXsqKple1UNO2Lh76qgbGf488kD+jFpv95tHoOSkYhIB1ZbF6e8spptO6sp3RH8Lgu3yytrKAu312wsocsbFVRU1lBWWc32qhoqKmvYHqulqrau1e9XurNayUhEZF8Wj8cpj9WwpaKKku0xNldUsXV7FVu2B79LdgS/t+6opnRH8Lussprm1kjN6JJGj+wMsrrE6ZvXhdysDPJ7d6NHdia5WRnkZmfQvWs63bMygp+uGXTPSienawY5XdPJCeu6dU0nJzOdjPTkPN1RMhIRSbLq2jo2lcfYWFbJxrIYxeWVbCqPUVwWY1NFjM0VMTaXx9i8vYqqmsZ7KTld0+md05Xe3TPpndOV4X1y6JWTSa+crvTqlknP+p+c4HdediZ53TLolplOWlpaSkwH1BwlIxGRvVBXF6e4PMba0h2sLa1kXelO1pXuZP22SjZsq2RDWSWbK2If6710SYO+uVn0z82iX48sDhiQS//cLPrmdqVfbhZ9c7Po270rfcKf7Mz0aC6wnSgZiYi0oCJWw4rN21lVsoNVJTtYuWUHq0t2sGbrDtaW7qS69qOZJi87gyG9ujGoZzZjh+YxMC+bQXnZDMzLpn+PLAbkZdG3exbpXdIiuqLUo2QkIkJwK23llu0s27SdZZu3s3zTdpZtrmD55h1sroh9ZN8+3bsyrHc3xg7tyUljB5PfuxtDe3cjv1c3BvfqRm6W/rTuLrWYiHQqO6tq+WBTBUuKy1mysYKlxRUs3VTBqi07qKn7sIfTL7crI/t15zjrz4h+3RnZrzvD++QwvG8OedmZEV7BvknJSET2SbV1cZZv3k7R+jLe31jO4g3lvL+xnFUlO3Y9v8noksZ+fXMYPSCXk8cOYlS/XPYfkMvIft3p2U0Jpz0pGYlIhxerqWPh6lIWrd3Gu+u28e66MnxDObFwZFp6lzRG9uvO2CE9OWvCUA4c2IPRA3IZ0a87mUkaqiy7R8lIRDqUyupaFm8o552123hnTSnvrC3DN5RRF18BQM9umYwZnMcFR+xHweA8Cgb3YP/+ufv8aLSOLiWSkZlNA24CCoAp7j4vLM8E7gUmEsT6oLvfknBcOjAPWOvup4VlI4FZQB9gPnChu1eZWRbwIDAJ2AJ8wd1XhMdcD1wC1AJXuvuzyb5mEWlePB4MmV68oRzfUEbR+nLeXbeNDzZtpzZ8ttOne1fGDu3J+H69OO6QURw8pCf5vbuRlqZRah1NSiQjYBFwNvC7BuXTgCx3H2dmOcB7ZvZwfRIBrgKKgLyEY24FbnP3WWb2W4Ikc0/4e6u7H2Bm54b7fcHMxgDnAgcDQ4DnzexAd69NypWKyEfE43E2lFWytLiCJRsrWFJcwdLicpYUV1C6o3rXfgPzsjh4SE9OOngQY4b0ZOzQ
PIb26pbwhc7BEV6F7K2USEbuXgRgZg2r4kB3M8sAugFVQFm4bz5wKnAzcE1YlgZMBaaHxz9A0OO6BzgjfA3wKHBnuP8ZwCx3jwHLzWwpMAV4rY0vU6RTi8fjrNtWyfsby1myMRjJ9n5xBR8UV1ARq9m1X6+cTA4c0IOTxw7moEE9sEE9sIE96N29a4TRS7KlRDJqxqMEyWI9kANc7e4lYd0M4FqgR8L+fYFSd6//ZK8BhoavhwKrAdy9xsy2hfsPBV5POEfiMY2KxWIUFRW1+iIqKyt3a//OSG3UvI7WPjuq61i+tYrlJTGWb61ixdYqVpRWsaP6w6HTvbulM7xnJseNzGF4z0yG9ezK8F6Z9MpOT7jNthNiO9mwqpgNzbxfR2ufKKR6G7VbMjKz54FBjVTd4O5PNHHYFILnOEOA3sAr4XnGAMXuXmhmxybs39iN4ngLdc0d06isrKzdmuMp1eeESgVqo+alcvts21nNorXbggEFa7fx7tptrNiyY1d9z26Z2KAefG7/QUEvZ1Awkq1XTtv1dFK5fVJFKrRRYWFhk3Xtlozc/YQ9OGw68Iy7VwPFZvYqMBmYAJxuZqcA2UCemT0EXAj0MrOMsHeUD6wLz7UGGAasCW/79QRKEsrrJR4jIglq6+Is3lDG/JVbWbC6lIWrS1m2afuu+vze3Rg7pCefm5jPmCF5FAzOY3DPbA0okBal+m26VcDUMNHkAEcAM9x9NnA9QNgz+ra7XxBuvwScQzCi7iKgvtf1ZLj9Wlj/orvHzexJ4C9m9muCHthoYG77XJ5IaqusrmXh6lLeWFbCvJUlLFhVuuv5Tr/crhw6rBdnTxjK+PxejBvaU891ZI+lRDIys7OA3wD9gafMbKG7nwjcBcwkGG2XBsx097dbON11wCwz+wmwALgvLL8P+FM4QKGEYAQd7v6umc0G3gNqgCs0kk46q5raOt5aU8orSzbz2gdbWLC6lKqaOtLSwAb24MwJQ5i8Xx8m7ddbQ6ilTaVEMnL3x4HHGymvIBje3dyxc4A5CdvLCJ41NdyvsqlzufvNBKPyRDqdNVt38JJv4j++iTeWbaE8VkNaGhw8JI8vHrEfR4zqy2Ej+tAzR9PjSPKkRDISkfZTWxdn/qqtPPfeRl5aXMyS4gogeN5z2iFD+NTofnxiVF/dcpN2pWQk0glU1dTx6tLNPLNoA88XbWTL9ioy09M4fGRfvnDYMI47aACj+nXXbTeJjJKRyD6quraO/1uymX++vZ7n3ttAWWUNPbIyOO6gAXx6zECOtf700FIIkiKUjET2IfF4nAWrS/n7grX88+31lGyvokd2Bp8ZM4hTxw/iqAP60zVDs1RL6lEyEtkHbNhWyd/mr+HRwjUs37ydrIwunDBmIGceOpSjD+xHVoZmrJbUpmQk0kHV1NbxwuJiZs1dxX/e30RdHA4f2YfLj9mfk8YN0mqk0qEoGYl0MBvLKpk1dzWz3lzF+m2VDMzL4vJj92fapGGM6Nc96vBE9oiSkUgHMX/VVma+uoKn31lPTV2cT43ux02nH8zxBw0gQ6uVSgenZCSSwmpq6/jXog3c9dxafPMyemRlcNEnR3DhEfupFyT7FCUjkRS0PVbDI2+u5r7/W87a0p0Mzcvkh6cfzOcm5ZObpX+2su/Rp1okhWypiDHz1RU8+NoKyiprmDKiDzedfjBD2MLBY0ZEHZ5I0igZiaSAdaU7+f3Ly5j15ipiNXWcOGYQlx0zignDewNQVFTSwhlEOjYlI5EIrdyynbtf+oC/zV8DwJkThvK1Y/bngAG5EUcm0r6UjEQi8MGmCu56cSlPvLWO9C5pTD98OF89ehT5vXOiDk0kEkpGIu1oaXE5d7ywlH+8vY7sjHQu/uQIvnr0KAbkZUcdmkiklIxE2sH7G8u544UlPPXOerplpnPZ0fvzlU+NpG9uVtShiaQEJSORJFqysZzbwySUk5nO147Zn698ahR9tFaQyEcoGYkkwdLiCm5/YQn/
fHsdOZnpXH7M/lyqJCTSJCUjkTa0bFMFd7ywhCffWke2ekIiraZkJNIGVpfs4PYXlvDY/DVkZaTzlaNH8dVPjdIzIZFWUjIS2QvrSnfymxeX8td5q0nvksbFR47ka8fsT/8eSkIiu0PJSGQPbK6IcddLS/nz66uIE2f64cO54rgDGKgh2iJ7RMlIZDds21nNH15exv2vLidWU8fnJg7lyuNH68uqInsp8mRkZtOAm4ACYIq7zwvLM4F7gYkEcT7o7rckHJcOzAPWuvtpYdlIYBbQB5gPXOjuVWZ2DXApUANsAr7s7ivDYy4Cvhee9ifu/kByr1g6osrqWh747wrunvMB23ZWc9r4wVz96QPZv7+m7RFpC6mwItci4Gzg5Qbl04Asdx8HTAIuM7MRCfVXAUUNjrkVuM3dRwNbgUvC8gXAZHcfDzwK/BzAzPoANwKHA1OAG82sdxtdl+wDamrreOTNVRz3yznc8vRiJgzvxVNXHsWd0ycqEYm0och7Ru5eBGBmDaviQHczywC6AVVAWbhvPnAqcDNwTViWBkwFpofHP0DQ47rH3V9KOO/rwAXh6xOB59y9JDzHc8BJwMNtdoHSIcXjcV5cXMzPnl7MkuIKDh3Wi9u+cChHjOobdWgi+6TIk1EzHgXOANYDOcDV9UkDmAFcC/RI2L8vUOruNeH2GmBoI+e9BHg6fD0UWJ1Q19QxHxGLxSgqatgpa1plZeVu7d8ZpVIbvb85xn2FW3h7QyVD8zL53rED+eTwHNJixRQVFUcSUyq1TypS+7Qs1duoXZKRmT0PDGqk6gZ3f6KJw6YAtcAQoDfwSnieMUCxuxea2bEJ+6c1co54gzguACYDx7T2mMZkZWVRUFDQ0m67FBUV7db+nVEqtNH6bTv5+TPO4wvW0rd7V358xsGcO2U4menR381OhfZJZWqflqVCGxUWFjZZ1y7JyN1P2IPDpgPPuHs1UGxmrxIkkgnA6WZ2CpAN5JnZQ8CFQC8zywh7R/nAuvqTmdkJwA3AMe4eC4vXAMcmvGc+MGcPYpUObGdVLb97+QN++58PqIvD5cfuz9eP3Z8e2ZlRhybSaaTybbpVwNQw0eQARwAz3H02cD1A2DP6trtfEG6/BJxDMKLuIuCJsHwC8DvgJHdPvM/yLPDThEELn6k/t+z74vE4T72znp8+VcS6bZWcOn4w3z3pIIb10TBtkfYWeTIys7OA3wD9gafMbKG7nwjcBcwkGG2XBsx097dbON11wCwz+wnBCLr7wvJfALnAX8OBEqvc/XR3LzGzHwNvhvv9KOG5lOzDfEM5Nz35Lq8t28KYwXnMOHcCU0b2iToskU4r8mTk7o8DjzdSXkEwvLu5Y+eQcFvN3ZcRPGtquF+Ttwnd/X7g/lYHLB3a9lgNM55/n/tfXUGP7Ax+cuZYzpsynPQujT0+FJH2EnkyEmkP8XicZ9/dyA//8S7rt1Vy3pThXHui0VuzaYukBCUj2eetK93J9/++iBcWF3PQoB7cOX0ik/bTd5tFUomSkeyz6uri/HnuKm59ejG1dXFuOKWAi48cQUYKDNUWkY9SMpJ90vLN27nu0beZu6KEIw/oyy1njWd4X42SE0lVSkayT6mri/PQGyu55V+LyUxP4+fnjGfapHzS0jRAQSSVtSoZmVm6u9cmOxiRvbGudCfXPvo2/7d0M8cc2J9bPzeeQT21vpBIR9DantF6M3sY+FP9Eg8iqeSJhWv53t8XUVsX56dnjeO8KcPUGxLpQFqbjE4Gzgf+YWalwJ+Ah9x9VdIiE2mF8spqfvDEuzy+YC0Th/dixhcm6NmQSAfUqmTk7oVAoZl9m2DKnAuAd8xsPkFiesTdtycvTJGPm79qK1fNWsDarTu56vjR/M/UAzRSTqSD2q1/ue5eBywOfzYRLLdwPrDazC5s+/BEPq6uLs49cz5g2m9fo64OZl/2Ca7+9IFKRCIdWGsHMPQGPk8wM3YBMBv4orv/N6w/DPg3QS9JJGk2
V8S4ZvZbvPz+Jk4dN5ifnj2Ont00u7ZIR9faZ0ZrgJeAO4AnEpZgAMDd3zSzptYlEmkT/126maseWUjZzmpuPmss06cM1yAFkX1Ea5PRAe6+vmGhmQ1y9w0A7v6ltgxMpF5tXZw7XljCHS8uYVS/7vzpkikcNCgv6rBEpA21Nhk50Ni//vcAzbsvSVNcVslVsxby2rItfG5iPj8+82Byuuq72iL7mtb+q/7YvRAzywPq2jYckQ+9smQTVz+ykO2xWn5xznimTR4WdUgikiTNJiMzWw3EgW5m1vA7RX2Bh5MVmHReNbV1zHh+CXfNWcroAbk8/JWJjB7YI+qwRCSJWuoZXUDQK/oXwUi6enFgo7t7sgKTzmnDtkqunLWAuctL+PzkfH54+li6dU2POiwRSbJmk5G7/wfAzPq5+472CUk6qzlezDWz36KyupbbvnAIZ03IjzokEWknTSYjM7vB3W8ON79rZo3u5+4/SEZg0nlU19Yxs7CE2YuW7Vr87oABuVGHJSLtqLmeUeJ/S/XkWJJiXelOrnx4AfNWlnLelOHc+NkxZGfqtpxIZ9NkMnL3yxNeX9w+4Uhn8pIXc80jC6mqqeO6owdw+Snjog5JRCLS3G26Ua05gbsva7twpDOoqa3j18+9z91zPuCgQT24+/yJxDavjjosEYlQc7fplhKMmmtuvpU4oHsq0mrFZZV84+FgtNx5U4Zx42cPJjsznaLNUUcmIlFq7jZdu02BbGbTgJsIJmGdUr+An5llAvcCEwlifdDdb0k4Lh2YB6x199PCspHALIKZIeYDF7p7VcIx5wB/BQ5LeJ/rgUuAWuBKd382qRfcSb2+bAvf+MsCtsdqNFpORD4iVebcXwScDbzcoHwakOXu44BJwGVmNiKh/iqgqMExtwK3uftoYCtBkgHAzHoAVwJvJJSNAc4FDgZOAu4Ok5y0kXg8zu/+8wHn3/sGedkZ/P2KI5WIROQjmntm9Iy7nxS+foXgltzHuPvRexuEuxeF79OwKg50N7MMoBtQBZSF++YDpwI3A9eEZWnAVGB6ePwDBD2ue8LtHwM/B76d8B5nALPCmciXm9lSYArw2t5el0BFrIbv/PUtnl60gVPGDeLWz42nR7aWfBCRj2rumdGDCa/vTXYgTXiUIFmsB3KAq929JKybAVwLJM4T0xcodfeacHsNwQKAmNkEYJi7/zNcsbbeUOD1hO1dxzQlFotRVNSwQ9a0ysrK3dp/X7G2rJofvbiBNWXVXDq5D2eP6caa5Usb3beztlFrqX2ap/ZpWaq3UXPPjP6S8PqBvX0jM3seGNRI1Q3u3tRaSFMInuMMAXoDr4TnGQMUu3uhmR2bsH9jgy3iZtYFuA34UiP1jR7TRDwAZGVlUVBQ0NwuH1FUVLRb++8LXly8kaufXkhGlzT+dMnhHHlAv2b374xttDvUPs1T+7QsFdqosLCwybpWz8VvZl8GziNIDOsIBgnc7+7N/uGu5+4ntPa9EkwHnnH3aqDYzF4FJgMTgNPN7BQgG8gzs4cI5s/rZWYZYe8oP4y1BzAWmBPeChwEPGlmpxP0hBK/1Ft/jOyBeDzO715exq3PLKZgUB6/u3ASw/rkRB2WiKS41i47/nOC22UzgJXAcILnLkZwqyxZVgFTw0STAxwBzHD32cD1YWzHAt929wvC7ZeAcwiS5UUEK9NuA3b919zM5oTHzDOzncBfzOzXBIl2NDA3ide0z4rV1PK/jy3ib/PXcOr4wfzynEM0yamItEpre0ZfAia6+5r6AjN7imDo9F4nIzM7C/gN0B94yswWuvuJwF3ATILRdmnATHd/u4XTXQfMMrOfAAuA+5rb2d3fNbPZBAsF1gBXuHvtXl1QJ7SlIsZlfypk3sqtfPOE0Vx1/GgtCS4irdbaZFQe/jQsK2uLINz9ceDxRsorCIZ3N3fsHGBOwvYygmdNzR1zbIPtmwlG5ckeWLapgotmzqW4LMZvzpvAZw8ZEnVIItLBtHY6
oBnAY2b2Mz58xvIdgkEB0okVrizh0gfm0SUtjVlfPYIJw3tHHZKIdEC7Ox3QcQ32mQrc2dZBScfwzKINXDVrAYN7ZvPHi6cwol/3qEMSkQ4qJaYDko7nL2+s4oa/v8Mh+b2476LJ9M3NijokEenAWj20W6Te7/7zAbc8vZjjrD93nz9JI+ZEZK+1dmh3BvB14BiCIdK7bt21xXRA0jHE43F+9e/3ufOlpZw2fjC//vyhdM1QB1pE9l5r/5LcBlxGMJHpJOBvwADgxSTFJSkmHo/zw3+8x50vLeXcw4Zx+7kTlIhEpM209q/J2cDJ7n47UBP+PpOPD2iQfVA8HufGJ9/lj/9dwaVHjeSWs8eR3kXfIRKRttPaZJQD1C/FudPMctx9McG0PLIPq+8RPfjaSr569ChuOLVAX2YVkTbX2gEMRcBhBNPkzANuMrMyYG2yApPoxeNxfvTP93b1iK4/+SAlIhFJitYmo6sIZs+GYO2gewgmH/1qMoKS6MXjcW55ejEzX13Bl48cqR6RiCRVq5KRu7+Z8HoJsCczcEsHcvecD/j9y8v44if24/unKRGJSHLtzhISU2mwhIS7v5CswCQ6D72+kl8865x56BBu+uzBSkQiknStGsBgZtcQLMlQAjwFbCFYduFbSYxNIvDkW+v4/hOLOP6gAfxi2iF00ag5EWkHre0ZfQuY6u6L6gvM7E/Ac8CvkhGYtL9XlmzimkcWctiIPtx1/kQy0/U9IhFpH7vz12Zpg+1ltLA8t3Qc760r4/KH5nPAgFzuvWgy2Zma4kdE2k9zS0gkJqqbgPvM7CY+XELi+8CNyQxO2se60p1c/Me59MjO4I8XTyEvOzPqkESkk2nuNl0NH/Z86h8cnNegbDpwb3JCk/awbWc1X5o5lx2xWv56+ScY1DM76pBEpBNqLhmNbLcoJBLVtXVc/lAhyzdv548XT+GgQXlRhyQinVRz6xmtbFgW3robCGx097pkBibJFY/H+cETi/jvB1v45bRDOPKAflGHJCKdWGuHdueZ2YNAJcEUQDvN7AEz65nU6CRp7vu/5Tw8dzVXHLc/50zKjzocEenkWjua7g6gOzAW6AaMI5g89Y4kxSVJ9ELRRm7+VxEnjx3Etz5tUYcjItLq7xmdBIxy9x3h9vtmdjHwQXLCkmRZvKGMKx9ewNghPfnV5/WlVhFJDa1NRpVAfyDxOVI/ILa3AZjZNIKh4wXAFHefF5ZnEozUmxjG+aC735JwXDrBDOJr3f20sGwkwUwRfYD5wIXuXhXWfT58nzjwlrtPD8svAr4XnvYn7v7A3l5TqirZXsWlD8wjNzuDP3xxMjldteq8iKSG1t6muxd4zsy+ZmYnm9nXgGeB37dBDIsIFu97uUH5NCDL3ccRrC57mZmNSKi/imBpi0S3Are5+2hgK3AJgJmNBq4HjnT3g4FvhuV9CL4rdTgwBbjRzHq3wTWlnOraOr7+50KKy2P87sLJGsItIimltcnoZuBnwDkE0/+cA/w8LN8r7l7k7t5IVRzobmYZBM+pqoAyADPLB04l4TtOZpYGTAUeDYseIFiNFuArwF3uvjV8z+Kw/ETgOXcvCeueI7gluc/58T/f4/VlJfzs7HEcOqxX1OGIiHxEi/dpwtthNwI3u/v9yQ9pl0eBM4D1BIMlrnb3krBuBnAtwZpK9foCpe5eE26vAYaGrw8EMLNXgXTgJnd/JqxfnXCOxGOaFIvFKCpq2ClrWmVl5W7t39aefr+MB1/bzNljelLQrTzSWJoSdRulOrVP89Q+LUv1NmoxGbl7rZldQfC8ZY+Y2fPAoEaqbnD3J5o4bArBgn5DgN7AK+F5xgDF7l5oZscm7N/Yk/j62SIygNHAsUB+eK6xLRzTpKysLAoKClrabZeioqLd2r8tFa4s4Z65y/nU6H784oIppKfogIUo26gjUPs0T+3TslRoo8LCwibrWvsE+wHga8DdexKAu+/JYnzTgWfcvRooDns1k4EJwOlm
dgqQDeSZ2UPAhUAvM8sIe0f5BOsuQdDjeT0813Izc4LktIYgQdXLB+bsQawpqbiskq89NJ8hvbpx53kTUzYRiYi0NhlNAf7HzK4luK21q/fg7kcnIzBgFTA1TDQ5wBHADHefTTAYgbBn9G13vyDcfongedYs4CKgvtf1d4J59f5oZv0IbtstIxia/tOEQQufqT93R1dVU8flf55PRWUND11yOD1zNPmpiKSu1iajP4Q/bc7MzgJ+QzB0/CkzW+juJwJ3ATMJRtulATPd/e0WTncdMMvMfgIsAO4Ly58FPmNm7xHc+vuOu28J3//HQP2y6j9KeC7Vof3on+9SuHIrd06fgA3q0fIBIiIRam0y+key/ki7++PA442UVxAM727u2Dkk3FZz92UEvbiG+8WBa8KfhnX3A+05MCPpZr+5modeX8Vlx4zitPFDog5HRKRFzQ7tNrMjzGwdsMnMVprZoe0Ul+yhRWu38b0nFnHUAf34zmc01Y+IdAwtfc/ol8CfCOaimx1uS4ratrOar/95Pn27d+X2cw8lQ8uGi0gH0dJfqzHA/7r7ewRT5oxNfkiyJ+rq4nxr9lusK93JndMn0jc3K+qQRERaraVklOHutQDuHgO6Jj8k2RO/f2UZzxdt5IZTC5i03z45o5GI7MNaGsCQHa5jVK97g23c/YttH5bsjrnLS/jFs86p4wfzpU+OiDocEZHd1lIyajj33E+TFYjsmdIdVVw1awHD++Rw6+fGk5amL7aKSMfTbDJy9x+2VyCy++LxON/92ztsrojx2OVHkpulJSFEpGPScKsO7OG5q3nm3Q1850RjXL5WgBeRjkvJqINasrGcH/3zXT41uh+XHjUq6nBERPaKklEHVFVTx5WzFtK9a4aWDheRfYIeMnRAv/3PBxStL+MPX5zMgB5asVVEOr5WJSMz+3ITVTE+XJ4h1mZRSZOWFpdz54tLOW38YD49ZmDU4YiItInW9oy+CHwC2EiQfPKBgcA8YASAmZ3h7vOSEKOE6uriXPe3d8jJSuem0w+OOhwRkTbT2mT0LvCYu99RX2Bm3wAOAo4CbiBYBuITbR6h7PLQGyspXLmVX007hH6a7kdE9iGtHcAwHbizQdk9wPnh8gy/IJjHTpJkbZ/djFsAABF9SURBVOlObn16MZ8a3Y+zJw6NOhwRkTbV2mS0Efhsg7JTgeLwdTZQ3VZBycfd/NR71MXhp2eN0ywLIrLPae1tuiuBv5rZIoJlx4cRzOBdv/jd4QS36SQJ3l5Tyr/e2cBVx49mWJ+cqMMREWlzrUpG7v5vM9sfOBkYAvwLeKp+6W53/zfw76RF2cn94lmnT/euXPqpkVGHIiKSFK3+0qu7bwb+A7wMzKlPRJJcry7dzCtLNnPFcQfQIzsz6nBERJKitd8zGgzMAo4ASoC+ZvY6cK67r0tifJ1aPB7n588sZkjPbM4/fHjU4YiIJE1re0b3AG8Bfdx9MNAbWAD8NlmBCTz77gbeWrONb376QLIz06MOR0QkaVo7gOEoYLC7VwO4+3YzuxZYm7TIOrma2jp+8axzwIBczp6godwism9rbc9oKx//HpEBpW0bjtR7vmgjH2zazjWfPpCMdM1nKyL7ttb2jH4OPG9m9wErgf2Ai4Hvt0UQZjYNuAkoAKbUTytkZpnAvcDEMNYH3f2WhOPSCaYkWuvup4VlIwmeb/UB5gMXunuVmQ0HHgB6AenAd939X+Ex1wOXALXAle7+bFtc196Y9eZqBuVl8xnNPycinUCr/svt7n8AvgD0I/jyaz/gPHf/fRvFsQg4m2CkXqJpQJa7jwMmAZeZ2YiE+quAogbH3Arc5u6jCXp0l4Tl3wNmu/sE4FzgbgAzGxNuHwycBNwdJrnIrCvdycvvb2La5Hz1ikSkU9idod0vuvul7n6Ku18K/MfMftQWQbh7kbt7I1VxoLuZZQDdgCqgDMDM8glmgbi3fmczSwOmAo+GRQ8AZyacKy983ROoHwV4BjDL
3WPuvhxYCkxpi+vaU48WrqEuDtMmDYsyDBGRdrM36xllEEyQ+oM2iqUxjxIki/VADnC1u5eEdTOAa4EeCfv3BUrdvSbcXgPUP/2/Cfi3mf0P0B04ISwfCryecI7EYxoVi8UoKmrYIWtaZWVlq/evi8d56L+rOXRwN7YXr6SouOVj9gW700adkdqneWqflqV6G+3t4nqtniTNzJ4HBjVSdYO7P9HEYVMInuMMIRhO/kp4njFAsbsXmtmxLcQTD3+fB/zR3X9lZp8A/mRmY1s4plFZWVkUFBQ0t8tHFBUVtXr/V5Zsonj7cr53+jgKCoa0+j06ut1po85I7dM8tU/LUqGNCgsLm6zb22TU7B/tRO5+Qst7fcx04JlwSHmxmb0KTAYmAKeb2SkEk7TmmdlDwIVALzPLCHtH+Xx4O+4SgmdCuPtrZpZN8OxrDcFce/USj2l3j7y5ml45mRq4ICKdSrPJyMymNlPdtY1jacwqYGqYaHIIZoCY4e6zgesBwp7Rt939gnD7JeAcghF1FwFPJJzreOCPZlZAkMQ2AU8CfzGzXxP0wEYDc9vh2j6mZHsV/353I9MPH64vuYpIp9JSz+i+FupXtUUQZnYWwazf/YGnzGyhu58I3AXMJBhtlwbMdPe3WzjddcAsM/sJwSwR9dfwLeAPZnY1QY/uS+FaTO+a2WzgPaAGuMLda9viunbX4wvWUlVbxxcO08AFEelcmk1G7t4u00S7++PA442UV/DhMhVNHTsHmJOwvYxGRsO5+3vAkU2c42bg5t2JORkem7+GQ/J7UjA4r+WdRUT2IfoSS4rYtqOa99aXcXyBnhWJSOejZJQiCleVEI/DYSP6RB2KiEi7UzJKEXOXbyUzPY1Dh/WKOhQRkXanZJQi5q0oYezQnnTrqlF0ItL5KBmlgMrqWt5es40pukUnIp2UklEKeHvNNqpq65isZCQinZSSUQp4c0Uw3d7k/XpHHImISDSUjFLA3OUlHDgwl97d22NSCxGR1KNkFLHaujjzV27VLToR6dSUjCK2eEMZ5bEaDV4QkU5NyShi81ZsBeCwkUpGItJ5KRlFbO6KEob0zGZor25RhyIiEhklowjF43HeXF6iXpGIdHpKRhFaXbKT4vKYBi+ISKenZBSh+u8XafCCiHR2SkYRWr55O+ld0jhgQG7UoYiIRErJKELF5ZX0y+1Kepe0qEMREYmUklGEistjDOiRHXUYIiKRUzKKUHFZjIF5WVGHISISOSWjCBWXV9JfPSMRESWjqNTU1rFlexUDeqhnJCKiZBSRzRVVxOMwQLfpRESUjKJSXF4JoAEMIiJARtQBmNk04CagAJji7vPC8kzgXmAiQZwPuvstCcelA/OAte5+Wlj2DeCbwP5Af3ffHJanAbcDpwA7gC+5+/yw7iLge+Fpf+LuDyT1gkPFZTEA3aYTESE1ekaLgLOBlxuUTwOy3H0cMAm4zMxGJNRfBRQ1OOZV4ARgZYPyk4HR4c9XgXsAzKwPcCNwODAFuNHM2mW51eLyMBnpNp2ISPTJyN2L3N0bqYoD3c0sA+gGVAFlAGaWD5xK0HNKPNcCd1/RyLnOIOhZxd39daCXmQ0GTgSec/cSd98KPAec1EaX1qzi8krS0qBfrpKRiEjkt+ma8ShBElkP5ABXu3tJWDcDuBbo0cpzDQVWJ2yvCcuaKm9WLBajqKhhp6xplZWVH9v//VWbyMvqwtL3G8vDnU9jbSQfUvs0T+3TslRvo3ZJRmb2PDCokaob3P2JJg6bAtQCQ4DewCvhecYAxe5eaGbHtjKExubbiTdT3qysrCwKCgpa+dZQVFT0sf2r5m5ncC926zz7ssbaSD6k9mme2qdlqdBGhYWFTda1SzJy9xP24LDpwDPuXg0Um9mrwGRgAnC6mZ0CZAN5ZvaQu1/QzLnWAMMStvOBdWH5sQ3K5+xBrLutuLySgXkaSSciAinwzKgZq4CpZpZmZt2BI4DF7n69u+e7+wjgXODFFhIR
wJPAF8NzHQFsc/f1wLPAZ8ysdzhw4TNhWdIVl8U0kk5EJBR5MjKzs8xsDfAJ4Ckzq08GdwG5BKPt3gRmuvvbLZzryvBc+cDbZlY/wOFfwDJgKfAH4OsA4TOoH4fnfxP4UcJzqaSpq4uzuSKmkXQiIqHIBzC4++PA442UVxAM727u2Dkk3FZz9zuAOxrZLw5c0cQ57gfu352Y91bJjipq6uL6wquISCjynlFnpC+8ioh8lJJRBHZNBaTbdCIigJJRJHbNvqDbdCIigJJRJDaFyai/btOJiABKRpHYWFZJXnYG2ZnpUYciIpISlIwiUFwWY4C+8CoisouSUQSC2Rd0i05EpJ6SUQSKy2MavCAikkDJqJ3F4/EwGalnJCJST8monZXtrKGqpk4j6UREEigZtbMPv/Cq23QiIvWUjNrZh194Vc9IRKSeklE729UzUjISEdlFyaidbayfJFW36UREdlEyamfFZTFyuqaTmxX56h0iIilDyaidFZdX6hadiEgDSkbtTF94FRH5OCWjdrapXMuNi4g0pGTUzorLKtUzEhFpQMmoHe2sqmV7Va16RiIiDSgZtaPszC5884TRnH7IkKhDERFJKRpf3I7S0tL45gkHRh2GiEjKiTwZmdk04CagAJji7vPC8kzgXmAiQZwPuvstCcelA/OAte5+Wlj2DeCbwP5Af3ffHJafD1wXHloBXO7ub4V1JwG3A+nAve7+s6ResIiIfEwq3KZbBJwNvNygfBqQ5e7jgEnAZWY2IqH+KqCowTGvAicAKxuULweOcffxwI+B38OuhHYXcDIwBjjPzMbs7QWJiMjuibxn5O5FAGbWsCoOdDezDKAbUAWUhfvmA6cCNwPXJJxrQWPncvf/Jmy+DuSHr6cAS919WXjcLOAM4L29vzIREWmtyJNRMx4lSAzrgRzgancvCetmANcCPfbgvJcAT4evhwKrE+rWAIe3dIJYLEZRUcNOWdMqKyt3a//OSG3UPLVP89Q+LUv1NmqXZGRmzwODGqm6wd2faOKwKUAtMAToDbwSnmcMUOzuhWZ27G7GcRxBMjoqLEprZLd4S+fJysqioKCg1e9bVFS0W/t3Rmqj5ql9mqf2aVkqtFFhYWGTde2SjNz9hD04bDrwjLtXA8Vm9iowGZgAnG5mpwDZQJ6ZPeTuFzR3MjMbTzAg4mR33xIWrwGGJeyWD6zbg1hFRGQvpPJtulXAVDN7iOA23RHADHefDVwPEPaMvt2KRDQceAy40N3fT6h6ExhtZiOBtcC5BElQRETaUeSj6czsLDNbA3wCeMrMng2r7gJyCUbbvQnMdPe3WzjXleG58oG3zezesOoHQF/gbjNbaGbzANy9BvgG8CzByLzZ7v5u216hiIi0JC0eb/ERiTRQWFi4iY8PHxcRkebtN2nSpP6NVSgZiYhI5CK/TSciIqJkJCIikVMyEhGRyCkZiYhI5JSMREQkckpGIiISuVSegWGfoPWSPsrMhgEPEsxVWAf83t1vN7M+wCPACGAF8Hl33xpVnFFruF5XOEvILKAPMJ9gNpGqKGOMkpn1IpjeayzBfJJfBhx9hgAws6uBSwna5h3gYmAwKfwZUs8oibReUqNqgG+5ewHBFE9XhG3yXeAFdx8NvBBud2YN1+u6FbgtbJ+tBBP+dma3E8xdeRBwCEFb6TMEmNlQ4EpgsruPJfiP8Lmk+GdIySi5dq2XFP4PpH69pE7L3de7+/zwdTnBH5GhBO3yQLjbA8CZ0UQYvYT1uu4Nt9OAqQTLqoDaJw84GrgPwN2r3L0UfYYSZQDdwvXgcgiW4knpz5CSUXI1tl7S0IhiSTnhyr0TgDeAge6+HoKEBQyIMLSo1a/XVRdu9wVKw7kUQZ+jUcAmYKaZLTCze82sO/oMAeDua4FfEkw2vR7YBhSS4p8hJaPk2qP1kjoDM8sF/gZ8093Loo4nVZjZaYTrdSUU63P0URnAROAed58AbKeT3pJrjJn1JugljiRYD647waOC
hlLqM6RklFxaL6kRZpZJkIj+7O6PhcUbzWxwWD8YKI4qvogdSbBe1wqC27pTCXpKvcJbLqDP0Rpgjbu/EW4/SpCc9BkKnAAsd/dN4XpwjwGfJMU/Q0pGybVrvSQz60rwEPHJiGOKVPj84z6gyN1/nVD1JHBR+PoioKkVgPdp7n69u+e7+wiCz8uL7n4+8BJwTrhbp20fAHffAKw2MwuLjgfeQ5+hequAI8wsJ/z3Vt8+Kf0Z0qzdSRauSDuDYETL/e5+c8QhRcrMjgJeIRhuWv9M5H8JnhvNBoYT/GOa5u4lkQSZIhIWjzzNzEbx4bDcBcAF7h6LMr4omdmhBAM8ugLLCIYud0GfIQDM7IfAFwhGry4gGOY9lBT+DCkZiYhI5HSbTkREIqdkJCIikVMyEhGRyCkZiYhI5JSMREQkckpGIp2YmcXN7ICo4xDREhIiKSSceWEgUJtQ/Ed3/0YkAYm0EyUjkdTzWXd/PuogRNqTkpFIB2BmXwK+QrAo2hcJZmO+wt1fCOuHAL8FjgJKgFvd/Q9hXTpwHcH6NQOA94Ez3b1+RvkTzOxpoB/wF+Ab7q5vw0u70jMjkY7jcIKpb/oBNwKPhSvkAjxMMIHoEIL5x35qZseHddcA5wGnAHkEq6LuSDjvacBhBIvUfR44MbmXIfJx6hmJpJ6/m1lNwvZ3gGqCWahnhL2WR8zsW8CpZjaHoEd0mrtXAgvN7F7gQoIVTy8FrnV3D8/3VoP3+1m4OF2pmb0EHAo8k6RrE2mUkpFI6jmz4TOj8Dbd2ga3z1YS9ISGACXhyrmJdZPD18OAD5p5vw0Jr3cAuXsYt8ge0206kY5jaLgkQL3hBGvSrAP6mFmPBnVrw9ergf3bJ0SRPaOekUjHMQC40szuBs4ECoB/ufsWM/svcIuZfRs4kGCwwgXhcfcCPzaz94ClwDiCXtaWdr8CkSYoGYmknn+YWeL3jJ4jWAjtDWA0sBnYCJyTkFDOIxhNtw7YCtzo7s+Fdb8GsoB/Ewx+WAycleyLENkdWs9IpAMInxld6u5HRR2LSDLomZGIiEROyUhERCKn23QiIhI59YxERCRySkYiIhI5JSMREYmckpGIiEROyUhERCL3/6yYx0BEbHZNAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(history.epochs, history.log_probabilities)\n",
    "plt.xlabel(\"Epoch\", fontsize=12)\n",
    "plt.ylabel(\"Log Probability\", fontsize=12)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we expected, the log probability of the data set goes up during training, because the model is being explicitly fit to the data set.\n",
    "\n",
    "Now let's look at how to pass in additional callbacks. Let's take a look at the CSV logger. All we have to do is create the CSVLogger object by passing in the name of the file to save to and then pass that object in to the fit function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>epoch</th>\n",
       "      <th>duration</th>\n",
       "      <th>total_improvement</th>\n",
       "      <th>improvement</th>\n",
       "      <th>log_probability</th>\n",
       "      <th>last_log_probability</th>\n",
       "      <th>epoch_start_time</th>\n",
       "      <th>epoch_end_time</th>\n",
       "      <th>n_seen_batches</th>\n",
       "      <th>learning_rate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0.002165</td>\n",
       "      <td>6079.062778</td>\n",
       "      <td>6079.062778</td>\n",
       "      <td>-184116.511097</td>\n",
       "      <td>-190195.573874</td>\n",
       "      <td>1.575228e+09</td>\n",
       "      <td>1.575228e+09</td>\n",
       "      <td>None</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>0.002047</td>\n",
       "      <td>6102.391290</td>\n",
       "      <td>23.328513</td>\n",
       "      <td>-184093.182584</td>\n",
       "      <td>-184116.511097</td>\n",
       "      <td>1.575228e+09</td>\n",
       "      <td>1.575228e+09</td>\n",
       "      <td>None</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>0.002156</td>\n",
       "      <td>6112.050803</td>\n",
       "      <td>9.659512</td>\n",
       "      <td>-184083.523072</td>\n",
       "      <td>-184093.182584</td>\n",
       "      <td>1.575228e+09</td>\n",
       "      <td>1.575228e+09</td>\n",
       "      <td>None</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>0.002037</td>\n",
       "      <td>6117.923677</td>\n",
       "      <td>5.872874</td>\n",
       "      <td>-184077.650198</td>\n",
       "      <td>-184083.523072</td>\n",
       "      <td>1.575228e+09</td>\n",
       "      <td>1.575228e+09</td>\n",
       "      <td>None</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>0.002295</td>\n",
       "      <td>6122.236641</td>\n",
       "      <td>4.312964</td>\n",
       "      <td>-184073.337234</td>\n",
       "      <td>-184077.650198</td>\n",
       "      <td>1.575228e+09</td>\n",
       "      <td>1.575228e+09</td>\n",
       "      <td>None</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   epoch  duration  total_improvement  improvement  log_probability  \\\n",
       "0      1  0.002165        6079.062778  6079.062778   -184116.511097   \n",
       "1      2  0.002047        6102.391290    23.328513   -184093.182584   \n",
       "2      3  0.002156        6112.050803     9.659512   -184083.523072   \n",
       "3      4  0.002037        6117.923677     5.872874   -184077.650198   \n",
       "4      5  0.002295        6122.236641     4.312964   -184073.337234   \n",
       "\n",
       "   last_log_probability  epoch_start_time  epoch_end_time n_seen_batches  \\\n",
       "0        -190195.573874      1.575228e+09    1.575228e+09           None   \n",
       "1        -184116.511097      1.575228e+09    1.575228e+09           None   \n",
       "2        -184093.182584      1.575228e+09    1.575228e+09           None   \n",
       "3        -184083.523072      1.575228e+09    1.575228e+09           None   \n",
       "4        -184077.650198      1.575228e+09    1.575228e+09           None   \n",
       "\n",
       "   learning_rate  \n",
       "0            0.0  \n",
       "1            0.0  \n",
       "2            0.0  \n",
       "3            0.0  \n",
       "4            0.0  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas\n",
    "from pomegranate.callbacks import CSVLogger\n",
    "\n",
    "d1 = MultivariateGaussianDistribution(numpy.zeros(13), numpy.eye(13))\n",
    "d2 = MultivariateGaussianDistribution(numpy.ones(13), numpy.eye(13))\n",
    "\n",
    "model = GeneralMixtureModel([d1, d2])\n",
    "model.fit(X, callbacks=[CSVLogger(\"logs.csv\")])\n",
    "\n",
    "logs = pandas.read_csv(\"logs.csv\")\n",
    "logs.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The CSV will now contain the information that the History object stores, but in a convenient written format. Note that some of the columns will correspond to information that isn't particularly useful for normal training, such as \"learning rate.\" While conceptually similar to the learning rate used in training neural networks, EM does not necessarily benefit in the same way that gradient descent does from tuning it."
   ]
  },
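  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The rows printed above suggest how the columns relate: `improvement` looks like `log_probability - last_log_probability`, and `total_improvement` looks like the running sum of `improvement`. This reading is inferred from the displayed values rather than taken from the documentation, so the snippet below simply checks it against the five rows shown.\n",
    "\n",
    "```python\n",
    "# (log_probability, last_log_probability, improvement, total_improvement)\n",
    "# copied from the first five rows of the table above.\n",
    "rows = [\n",
    "    (-184116.511097, -190195.573874, 6079.062778, 6079.062778),\n",
    "    (-184093.182584, -184116.511097, 23.328513, 6102.391290),\n",
    "    (-184083.523072, -184093.182584, 9.659512, 6112.050803),\n",
    "    (-184077.650198, -184083.523072, 5.872874, 6117.923677),\n",
    "    (-184073.337234, -184077.650198, 4.312964, 6122.236641),\n",
    "]\n",
    "\n",
    "running_total = 0.0\n",
    "consistent = True\n",
    "for logp, last_logp, improvement, total in rows:\n",
    "    running_total += improvement\n",
    "    # Values are printed to six decimals, so allow a small tolerance.\n",
    "    if abs((logp - last_logp) - improvement) > 1e-4:\n",
    "        consistent = False\n",
    "    if abs(running_total - total) > 1e-4:\n",
    "        consistent = False\n",
    "\n",
    "print(consistent)\n",
    "```\n",
    "\n",
    "If the check passes, the per-epoch `improvement` column is the one to watch for convergence, since it shrinks quickly after the first epoch."
   ]
  },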
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Implementing a custom callback\n",
    "\n",
    "Now let's look at an example of creating a custom callback. This callback will take in a training and a validation set and output both the training and validation set log probabilities. Currently, pomegranate does not allow for a user to pass a validation set in to the fit function and monitor performance that way, so this custom callback is an easy way around that limitation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pomegranate.callbacks import Callback\n",
    "\n",
    "class ValidationSetCallback(Callback):\n",
    "    \"\"\"This callback evaluates a validation set after each epoch.\"\"\"\n",
    "\n",
    "    def __init__(self, X_train, X_valid):\n",
    "        self.X_train = X_train\n",
    "        self.X_valid = X_valid\n",
    "        self.model = None\n",
    "        self.params = None\n",
    "\n",
    "    def on_epoch_end(self, logs):\n",
    "        \"\"\"Functionality to add to the end of each epoch.\n",
    "\n",
    "        This method will be called at the end of each epoch during the model's\n",
    "        iterative training procedure.\n",
    "        \"\"\"\n",
    "    \n",
    "        epoch = logs['epoch']\n",
    "        train_logp = self.model.log_probability(self.X_train).sum()\n",
    "        valid_logp = self.model.log_probability(self.X_valid).sum()\n",
    "        \n",
    "        print(\"Epoch {} -- Training LogP: {:4.4} -- Validation LogP: {:4.4}\".format(epoch, train_logp, valid_logp))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The above code seems fairly simple. All we do is store the data sets that are passed in and then calculate their respective log probabilities at the end of each epoch and print that out to the screen. Let's see how it works on a data set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 -- Training LogP: -5.874e+03 -- Validation LogP: -2.949e+03\n",
      "Epoch 2 -- Training LogP: -5.825e+03 -- Validation LogP: -2.93e+03\n",
      "Epoch 3 -- Training LogP: -5.747e+03 -- Validation LogP: -2.895e+03\n",
      "Epoch 4 -- Training LogP: -5.675e+03 -- Validation LogP: -2.859e+03\n",
      "Epoch 5 -- Training LogP: -5.647e+03 -- Validation LogP: -2.845e+03\n",
      "Epoch 6 -- Training LogP: -5.64e+03 -- Validation LogP: -2.842e+03\n",
      "Epoch 7 -- Training LogP: -5.637e+03 -- Validation LogP: -2.842e+03\n",
      "Epoch 8 -- Training LogP: -5.635e+03 -- Validation LogP: -2.842e+03\n",
      "Epoch 9 -- Training LogP: -5.634e+03 -- Validation LogP: -2.843e+03\n",
      "Epoch 10 -- Training LogP: -5.633e+03 -- Validation LogP: -2.843e+03\n",
      "Epoch 11 -- Training LogP: -5.632e+03 -- Validation LogP: -2.844e+03\n",
      "Epoch 12 -- Training LogP: -5.632e+03 -- Validation LogP: -2.844e+03\n",
      "Epoch 13 -- Training LogP: -5.631e+03 -- Validation LogP: -2.844e+03\n",
      "Epoch 14 -- Training LogP: -5.631e+03 -- Validation LogP: -2.844e+03\n",
      "Epoch 15 -- Training LogP: -5.631e+03 -- Validation LogP: -2.845e+03\n",
      "Epoch 16 -- Training LogP: -5.63e+03 -- Validation LogP: -2.845e+03\n",
      "Epoch 17 -- Training LogP: -5.63e+03 -- Validation LogP: -2.845e+03\n",
      "Epoch 18 -- Training LogP: -5.63e+03 -- Validation LogP: -2.845e+03\n"
     ]
    }
   ],
   "source": [
    "numpy.random.seed(0)\n",
    "\n",
    "X_train = numpy.concatenate([\n",
    "    numpy.random.normal(0, 1.0, size=(500, 5)),\n",
    "    numpy.random.normal(0.3, 0.8, size=(500, 5)),\n",
    "    numpy.random.normal(-0.3, 0.4, size=(500, 5))\n",
    "])\n",
    "\n",
    "idx = numpy.arange(X_train.shape[0])\n",
    "numpy.random.shuffle(idx)\n",
    "X_train = X_train[idx]\n",
    "\n",
    "X_valid = X_train[:500]\n",
    "X_train = X_train[500:]\n",
    "\n",
    "\n",
    "callback = ValidationSetCallback(X_train, X_valid)\n",
    "\n",
    "d1 = MultivariateGaussianDistribution(numpy.zeros(5), numpy.eye(5))\n",
    "d2 = MultivariateGaussianDistribution(numpy.ones(5), numpy.eye(5))\n",
    "d3 = MultivariateGaussianDistribution(-numpy.ones(5), numpy.eye(5))\n",
    "\n",
    "model = GeneralMixtureModel([d1, d2, d3])\n",
    "_ = model.fit(X_train, callbacks=[callback])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Simple! All we did was create the object and then pass it in to the callbacks parameter during fitting. We can see that the validation log probability initially goes down with the training set log probability, but then goes up a little bit, showing that the model is beginning to overfit a tiny bit to the data set."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
